{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TF-TRT Keras Classification Examples:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we cover a variety of classification base networks pulled from the tensorflow.keras.applications project!\n",
    "\n",
    "This demonstrates TF-TRT working on a variety of model architectures out of the box. This is a great way to demonstrate the ease of use of TF-TRT. TF-TRT can still optimize parts of your network even if it contains layers that are not supported by TensorRT itself. This makes it easy to get a first-pass at an optimized model - as we will demonstrate here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make sure our GPUs are properly configured and visible:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fri Jan 29 22:55:18 2021       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.1     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |\n",
      "| N/A   43C    P0    62W / 300W |    125MiB / 16155MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |\n",
      "| N/A   42C    P0    38W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  Tesla V100-DGXS...  On   | 00000000:0E:00.0 Off |                    0 |\n",
      "| N/A   41C    P0    38W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  Tesla V100-DGXS...  On   | 00000000:0F:00.0 Off |                    0 |\n",
      "| N/A   42C    P0    37W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remember to sucessfully deploy a TensorRT model, you have to make __five key decisions__:\n",
    "\n",
    "1. __What format should I save my model in?__\n",
    "2. __What batch size(s) am I running inference at?__\n",
    "3. __What precision am I running inference at?__\n",
    "4. __What TensorRT path am I using to convert my model?__\n",
    "5. __What runtime am I targeting?__\n",
    "\n",
    "Let's get to it!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. What format should I save my model in?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TF-TRT requires SavedModel format in Tensorflow 2.x:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p tmp_savedmodels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "downloading and initializing models...\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5\n",
      "553467904/553467096 [==============================] - 73s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels.h5\n",
      "96116736/96112376 [==============================] - 3s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels.h5\n",
      "91889664/91884032 [==============================] - 5s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224.h5\n",
      "14540800/14536120 [==============================] - 1s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/densenet/densenet121_weights_tf_dim_ordering_tf_kernels.h5\n",
      "33193984/33188688 [==============================] - 1s 0us/step\n",
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels.h5\n",
      "102875136/102869336 [==============================] - 10s 0us/step\n",
      "saving <tensorflow.python.keras.engine.functional.Functional object at 0x7f5cb859b0f0> ...\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/0/assets\n",
      "finished saving!\n",
      "saving <tensorflow.python.keras.engine.functional.Functional object at 0x7f5cb8519cc0> ...\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/1/assets\n",
      "finished saving!\n",
      "saving <tensorflow.python.keras.engine.functional.Functional object at 0x7f5cb0172208> ...\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/2/assets\n",
      "finished saving!\n",
      "saving <tensorflow.python.keras.engine.functional.Functional object at 0x7f5c343f57b8> ...\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/3/assets\n",
      "finished saving!\n",
      "saving <tensorflow.python.keras.engine.functional.Functional object at 0x7f5bdc4ed048> ...\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/4/assets\n",
      "finished saving!\n",
      "saving <tensorflow.python.keras.engine.functional.Functional object at 0x7f5b984a73c8> ...\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/5/assets\n",
      "finished saving!\n",
      "saving <tensorflow.python.keras.engine.functional.Functional object at 0x7f5b7856e2b0> ...\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/6/assets\n",
      "finished saving!\n"
     ]
    }
   ],
   "source": [
    "from tensorflow.keras.applications import ResNet50, VGG16, InceptionV3, Xception, MobileNetV2, DenseNet121, ResNet50V2\n",
    "\n",
    "print(\"Downloading and initializing models...\")\n",
    "models = [ResNet50, VGG16, InceptionV3, Xception, MobileNetV2, DenseNet121, ResNet50V2]\n",
    "models = [model(include_top=True, weights='imagenet') for model in models]\n",
    "\n",
    "model_dirs = []\n",
    "for idx, model in enumerate(models):\n",
    "    print(\"Saving\", model,\"...\")\n",
    "    model_dir = 'tmp_savedmodels/%s' % idx\n",
    "    model_dirs.append(model_dir)\n",
    "    model.save(model_dir) \n",
    "    print(\"Finished saving!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. What batch size(s) am I running inference at?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use a batch size of 32 for all models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We create a series of randomized \"dummy\" batches to test our model on:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dummy_input_batch = lambda x: np.zeros((BATCH_SIZE, x, x, 3))\n",
    "\n",
    "dummy_inputs = [224, 224, 299, 299, 224, 224, 224]\n",
    "dummy_inputs = [dummy_input_batch(size) for size in dummy_inputs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Last, we \"warm up\" all of our models so their one time start-up costs aren't throw off any of our Jupyter magic %%timeit timer calls:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f59d856e488> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n",
      "WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f5b3820d6a8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n",
      "WARNING:tensorflow:7 out of the last 7 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f596ca93950> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.\n"
     ]
    }
   ],
   "source": [
    "# Warm up:\n",
    "for idx, model in enumerate(models):\n",
    "    model.predict(dummy_inputs[idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. What precision am I running inference at?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will leave it as the default:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "PRECISION = \"FP32\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. What TensorRT tool or integration am I using to convert my model?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using TF-TRT through the ModelOptimizer example wrapper used in this guide:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting resnet50 tmp_savedmodels/0\n",
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/0_FP32/assets\n",
      "[[1.6964252e-04 3.3007402e-04 6.1350249e-05 ... 1.4622317e-05\n",
      "  1.4449877e-04 6.6086568e-04]\n",
      " [1.6964252e-04 3.3007402e-04 6.1350249e-05 ... 1.4622317e-05\n",
      "  1.4449877e-04 6.6086568e-04]\n",
      " [1.6964252e-04 3.3007402e-04 6.1350249e-05 ... 1.4622317e-05\n",
      "  1.4449877e-04 6.6086568e-04]\n",
      " ...\n",
      " [1.6964252e-04 3.3007402e-04 6.1350249e-05 ... 1.4622317e-05\n",
      "  1.4449877e-04 6.6086568e-04]\n",
      " [1.6964252e-04 3.3007402e-04 6.1350249e-05 ... 1.4622317e-05\n",
      "  1.4449877e-04 6.6086568e-04]\n",
      " [1.6964252e-04 3.3007402e-04 6.1350249e-05 ... 1.4622317e-05\n",
      "  1.4449877e-04 6.6086568e-04]]\n",
      "Finished!\n",
      "\n",
      "Starting vgg16 tmp_savedmodels/1\n",
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_1_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/1_FP32/assets\n",
      "[[0.00022801 0.00222478 0.00050746 ... 0.00011863 0.00026599 0.01312881]\n",
      " [0.00022801 0.00222478 0.00050746 ... 0.00011863 0.00026599 0.01312881]\n",
      " [0.00022801 0.00222478 0.00050746 ... 0.00011863 0.00026599 0.01312881]\n",
      " ...\n",
      " [0.00022801 0.00222478 0.00050746 ... 0.00011863 0.00026599 0.01312881]\n",
      " [0.00022801 0.00222478 0.00050746 ... 0.00011863 0.00026599 0.01312881]\n",
      " [0.00022801 0.00222478 0.00050746 ... 0.00011863 0.00026599 0.01312881]]\n",
      "Finished!\n",
      "\n",
      "Starting inception_v3 tmp_savedmodels/2\n",
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_2_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/2_FP32/assets\n",
      "[[0.00043102 0.00033233 0.0002535  ... 0.00012701 0.00023254 0.00082577]\n",
      " [0.00043102 0.00033233 0.0002535  ... 0.00012701 0.00023254 0.00082577]\n",
      " [0.00043102 0.00033233 0.0002535  ... 0.00012701 0.00023254 0.00082577]\n",
      " ...\n",
      " [0.00043102 0.00033233 0.0002535  ... 0.00012701 0.00023254 0.00082577]\n",
      " [0.00043102 0.00033233 0.0002535  ... 0.00012701 0.00023254 0.00082577]\n",
      " [0.00043102 0.00033233 0.0002535  ... 0.00012701 0.00023254 0.00082577]]\n",
      "Finished!\n",
      "\n",
      "Starting xception tmp_savedmodels/3\n",
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_3_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/3_FP32/assets\n",
      "[[0.00022673 0.00034859 0.00021873 ... 0.00012943 0.00032854 0.00086526]\n",
      " [0.00022673 0.00034859 0.00021873 ... 0.00012943 0.00032854 0.00086526]\n",
      " [0.00022673 0.00034859 0.00021873 ... 0.00012943 0.00032854 0.00086526]\n",
      " ...\n",
      " [0.00022673 0.00034859 0.00021873 ... 0.00012943 0.00032854 0.00086526]\n",
      " [0.00022673 0.00034859 0.00021873 ... 0.00012943 0.00032854 0.00086526]\n",
      " [0.00022673 0.00034859 0.00021873 ... 0.00012943 0.00032854 0.00086526]]\n",
      "Finished!\n",
      "\n",
      "Starting mobilenetv2_1.00_224 tmp_savedmodels/4\n",
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_4_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/4_FP32/assets\n",
      "[[1.8110585e-04 6.4528472e-04 6.8695762e-04 ... 7.9570833e-05\n",
      "  1.3486181e-04 3.3463116e-03]\n",
      " [1.8110585e-04 6.4528472e-04 6.8695762e-04 ... 7.9570833e-05\n",
      "  1.3486181e-04 3.3463116e-03]\n",
      " [1.8110585e-04 6.4528472e-04 6.8695762e-04 ... 7.9570833e-05\n",
      "  1.3486181e-04 3.3463116e-03]\n",
      " ...\n",
      " [1.8110585e-04 6.4528472e-04 6.8695762e-04 ... 7.9570833e-05\n",
      "  1.3486181e-04 3.3463116e-03]\n",
      " [1.8110585e-04 6.4528472e-04 6.8695762e-04 ... 7.9570833e-05\n",
      "  1.3486181e-04 3.3463116e-03]\n",
      " [1.8110585e-04 6.4528472e-04 6.8695762e-04 ... 7.9570833e-05\n",
      "  1.3486181e-04 3.3463116e-03]]\n",
      "Finished!\n",
      "\n",
      "Starting densenet121 tmp_savedmodels/5\n",
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_5_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/5_FP32/assets\n",
      "[[2.3581024e-04 3.7533988e-04 1.1308040e-04 ... 5.6219425e-05\n",
      "  2.6299071e-04 1.1581751e-03]\n",
      " [2.3581024e-04 3.7533988e-04 1.1308040e-04 ... 5.6219425e-05\n",
      "  2.6299071e-04 1.1581751e-03]\n",
      " [2.3581024e-04 3.7533988e-04 1.1308040e-04 ... 5.6219425e-05\n",
      "  2.6299071e-04 1.1581751e-03]\n",
      " ...\n",
      " [2.3581024e-04 3.7533988e-04 1.1308040e-04 ... 5.6219425e-05\n",
      "  2.6299071e-04 1.1581751e-03]\n",
      " [2.3581024e-04 3.7533988e-04 1.1308040e-04 ... 5.6219425e-05\n",
      "  2.6299071e-04 1.1581751e-03]\n",
      " [2.3581024e-04 3.7533988e-04 1.1308040e-04 ... 5.6219425e-05\n",
      "  2.6299071e-04 1.1581751e-03]]\n",
      "Finished!\n",
      "\n",
      "Starting resnet50v2 tmp_savedmodels/6\n",
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_6_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/6_FP32/assets\n",
      "[[0.00082353 0.00079469 0.00060477 ... 0.00036948 0.00069747 0.00154858]\n",
      " [0.00082353 0.00079469 0.00060477 ... 0.00036948 0.00069747 0.00154858]\n",
      " [0.00082353 0.00079469 0.00060477 ... 0.00036948 0.00069747 0.00154858]\n",
      " ...\n",
      " [0.00082353 0.00079469 0.00060477 ... 0.00036948 0.00069747 0.00154858]\n",
      " [0.00082353 0.00079469 0.00060477 ... 0.00036948 0.00069747 0.00154858]\n",
      " [0.00082353 0.00079469 0.00060477 ... 0.00036948 0.00069747 0.00154858]]\n",
      "Finished!\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from helper import ModelOptimizer\n",
    "\n",
    "opt_models = []\n",
    "for model_class, model, dummy in zip(models, model_dirs, dummy_inputs):\n",
    "    print(\"Starting\", model_class._name, model)\n",
    "    model_opt = ModelOptimizer(model)\n",
    "    opt_trt = model_opt.convert(model+'_'+PRECISION, precision=PRECISION)\n",
    "\n",
    "    print(opt_trt.predict(dummy))\n",
    "    \n",
    "    opt_models.append(opt_trt)\n",
    "    \n",
    "    print(\"Finished!\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. What TensorRT runtime am I targeting?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will stay inside our Tensorflow/Python runtime:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.00082353, 0.00079469, 0.00060477, ..., 0.00036948, 0.00069747,\n",
       "        0.00154858],\n",
       "       [0.00082353, 0.00079469, 0.00060477, ..., 0.00036948, 0.00069747,\n",
       "        0.00154858],\n",
       "       [0.00082353, 0.00079469, 0.00060477, ..., 0.00036948, 0.00069747,\n",
       "        0.00154858],\n",
       "       ...,\n",
       "       [0.00082353, 0.00079469, 0.00060477, ..., 0.00036948, 0.00069747,\n",
       "        0.00154858],\n",
       "       [0.00082353, 0.00079469, 0.00060477, ..., 0.00036948, 0.00069747,\n",
       "        0.00154858],\n",
       "       [0.00082353, 0.00079469, 0.00060477, ..., 0.00036948, 0.00069747,\n",
       "        0.00154858]], dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "opt_models[idx].predict(dummy_inputs[idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance Comparisons:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = 0 #resnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 160 ms, sys: 5.52 ms, total: 166 ms\n",
      "Wall time: 148 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[1.69642386e-04, 3.30075040e-04, 6.13506127e-05, ...,\n",
       "        1.46224065e-05, 1.44499005e-04, 6.60870341e-04],\n",
       "       [1.69642386e-04, 3.30075040e-04, 6.13506127e-05, ...,\n",
       "        1.46224065e-05, 1.44499005e-04, 6.60870341e-04],\n",
       "       [1.69642386e-04, 3.30075040e-04, 6.13506127e-05, ...,\n",
       "        1.46224065e-05, 1.44499005e-04, 6.60870341e-04],\n",
       "       ...,\n",
       "       [1.69642386e-04, 3.30075040e-04, 6.13506127e-05, ...,\n",
       "        1.46224065e-05, 1.44499005e-04, 6.60870341e-04],\n",
       "       [1.69642386e-04, 3.30075040e-04, 6.13506127e-05, ...,\n",
       "        1.46224065e-05, 1.44499005e-04, 6.60870341e-04],\n",
       "       [1.69642386e-04, 3.30075040e-04, 6.13506127e-05, ...,\n",
       "        1.46224065e-05, 1.44499005e-04, 6.60870341e-04]], dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "models[idx].predict(dummy_inputs[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 30.2 ms, sys: 8.3 ms, total: 38.5 ms\n",
      "Wall time: 36.6 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[1.6964252e-04, 3.3007402e-04, 6.1350249e-05, ..., 1.4622317e-05,\n",
       "        1.4449877e-04, 6.6086568e-04],\n",
       "       [1.6964252e-04, 3.3007402e-04, 6.1350249e-05, ..., 1.4622317e-05,\n",
       "        1.4449877e-04, 6.6086568e-04],\n",
       "       [1.6964252e-04, 3.3007402e-04, 6.1350249e-05, ..., 1.4622317e-05,\n",
       "        1.4449877e-04, 6.6086568e-04],\n",
       "       ...,\n",
       "       [1.6964252e-04, 3.3007402e-04, 6.1350249e-05, ..., 1.4622317e-05,\n",
       "        1.4449877e-04, 6.6086568e-04],\n",
       "       [1.6964252e-04, 3.3007402e-04, 6.1350249e-05, ..., 1.4622317e-05,\n",
       "        1.4449877e-04, 6.6086568e-04],\n",
       "       [1.6964252e-04, 3.3007402e-04, 6.1350249e-05, ..., 1.4622317e-05,\n",
       "        1.4449877e-04, 6.6086568e-04]], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "opt_models[idx].predict(dummy_inputs[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx = -3 # mobilenets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 105 ms, sys: 14.4 ms, total: 120 ms\n",
      "Wall time: 63.5 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[1.8110899e-04, 6.4530974e-04, 6.8695901e-04, ..., 7.9570033e-05,\n",
       "        1.3486811e-04, 3.3462986e-03],\n",
       "       [1.8110899e-04, 6.4530974e-04, 6.8695901e-04, ..., 7.9570033e-05,\n",
       "        1.3486811e-04, 3.3462986e-03],\n",
       "       [1.8110899e-04, 6.4530974e-04, 6.8695901e-04, ..., 7.9570033e-05,\n",
       "        1.3486811e-04, 3.3462986e-03],\n",
       "       ...,\n",
       "       [1.8110899e-04, 6.4530974e-04, 6.8695901e-04, ..., 7.9570033e-05,\n",
       "        1.3486811e-04, 3.3462986e-03],\n",
       "       [1.8110899e-04, 6.4530974e-04, 6.8695901e-04, ..., 7.9570033e-05,\n",
       "        1.3486811e-04, 3.3462986e-03],\n",
       "       [1.8110899e-04, 6.4530974e-04, 6.8695901e-04, ..., 7.9570033e-05,\n",
       "        1.3486811e-04, 3.3462986e-03]], dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "models[idx].predict(dummy_inputs[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 19.9 ms, sys: 4.48 ms, total: 24.4 ms\n",
      "Wall time: 22.4 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[1.8110585e-04, 6.4528472e-04, 6.8695762e-04, ..., 7.9570833e-05,\n",
       "        1.3486181e-04, 3.3463116e-03],\n",
       "       [1.8110585e-04, 6.4528472e-04, 6.8695762e-04, ..., 7.9570833e-05,\n",
       "        1.3486181e-04, 3.3463116e-03],\n",
       "       [1.8110585e-04, 6.4528472e-04, 6.8695762e-04, ..., 7.9570833e-05,\n",
       "        1.3486181e-04, 3.3463116e-03],\n",
       "       ...,\n",
       "       [1.8110585e-04, 6.4528472e-04, 6.8695762e-04, ..., 7.9570833e-05,\n",
       "        1.3486181e-04, 3.3463116e-03],\n",
       "       [1.8110585e-04, 6.4528472e-04, 6.8695762e-04, ..., 7.9570833e-05,\n",
       "        1.3486181e-04, 3.3463116e-03],\n",
       "       [1.8110585e-04, 6.4528472e-04, 6.8695762e-04, ..., 7.9570833e-05,\n",
       "        1.3486181e-04, 3.3463116e-03]], dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "opt_models[idx].predict(dummy_inputs[idx])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
